import torch
from torch.utils.data import Dataset, DataLoader
import os

# Directory holding the serialized inputs, plus the two file names read below.
data_path = "/home/chonghao/xue_research/supplementary/data/irradiation_v3"
filename_pkl, video_pkl = "all_data.pkl", "synthesized_eta.pkl"

# Both files are deserialized eagerly at import time, in the order listed.
# NOTE(review): torch.load unpickles its input -- only point this at trusted files.
all_data, video_data = (
    torch.load(os.path.join(data_path, pkl_name))
    for pkl_name in (filename_pkl, video_pkl)
)

class Container(torch.nn.Module):
    """Hold the concatenated ``cv``, ``ci`` and ``eta`` tensors plus ``video``.

    Args:
        my_values: Mapping with the keys ``'cv'``, ``'ci'``, ``'eta'`` and
            ``'video'``. The first three must be non-empty sequences of
            tensors that ``torch.cat`` can join along dimension 0; ``'video'``
            is stored as given, without copying.

    Raises:
        KeyError: If one of the four keys is missing.
        RuntimeError: If ``torch.cat`` rejects a sequence (for example when it
            is empty or its tensors have incompatible shapes).
    """

    # NOTE(review): kept for the (currently disabled) TorchScript export;
    # confirm tensors are acceptable as constants before re-enabling scripting.
    __constants__ = ['cv', 'ci', 'eta', 'video']

    def __init__(self, my_values):
        super().__init__()
        # The result shape comes from the data itself, so no frame size is
        # hard-coded. The cast keeps the dtype of the former torch.Tensor(...)
        # output buffers (the default floating-point dtype).
        dtype = torch.get_default_dtype()
        self.cv = torch.cat(my_values['cv']).to(dtype)
        self.ci = torch.cat(my_values['ci']).to(dtype)
        self.eta = torch.cat(my_values['eta']).to(dtype)
        self.video = my_values['video']

    def get_cv(self):
        """Return the concatenated ``cv`` tensor."""
        return self.cv

    def get_ci(self):
        """Return the concatenated ``ci`` tensor."""
        return self.ci

    def get_eta(self):
        """Return the concatenated ``eta`` tensor."""
        return self.eta

    def get_video(self):
        """Return the ``video`` object exactly as it was passed in."""
        return self.video

    

# Build the constructor input for Container: for each field, one tensor per
# entry of all_data with a new leading dimension of size 1, so the entries can
# be concatenated along dimension 0.
# unsqueeze (not the in-place unsqueeze_) leaves the tensors stored in all_data
# untouched, so re-running this block does not keep adding dimensions.
my_values = {
    field: [all_data[i][field].unsqueeze(0) for i in range(len(all_data))]
    for field in ('cv', 'ci', 'eta')
}
# Four-slice indexing: a full view of video_data (which must be 4-D).
my_values['video'] = video_data[:, :, :, :]

# my_values_2 = {
#     'video': video_data[:, :, :, :]
# }

container = Container(my_values)
# container = torch.jit.script(Container(my_values))
# container.save(os.path.join(data_path, "container_all_data.pt"))

# container_2 = torch.jit.script(Container(my_values_2))
# container_2.save(os.path.join(data_path, "container_video_data.pt"))

# TODO: directly write data to file and pickle_load in cpp

# # save cv
# import io 
# f = io.BytesIO()
# torch.save(container.get_cv(), f, _use_new_zipfile_serialization=True)
# with open(os.path.join(data_path, "cv_all_data.pt"), "wb") as outfile:
#     outfile.write(f.getbuffer())

# # save ci
# import io 
# f = io.BytesIO()
# torch.save(container.get_ci(), f, _use_new_zipfile_serialization=True)
# with open(os.path.join(data_path, "ci_all_data.pt"), "wb") as outfile:
#     outfile.write(f.getbuffer())

# # save eta
# import io 
# f = io.BytesIO()
# torch.save(container.get_eta(), f, _use_new_zipfile_serialization=True)
# with open(os.path.join(data_path, "eta_all_data.pt"), "wb") as outfile:
#     outfile.write(f.getbuffer())

# # save video
# import io 
# f = io.BytesIO()
# torch.save(container.get_video(), f, _use_new_zipfile_serialization=True)
# with open(os.path.join(data_path, "video_all_data.pt"), "wb") as outfile:
#     outfile.write(f.getbuffer())

# test cv
# test ci
# test eta
# test video